{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ChannelMutator\n",
    "A channel mutator is a manager of the channel structure of a model. In other words, it manages all MutableChannelUnits of a model.  \n",
    "ChannelMutator is the simplest channel mutator. All other channel mutators should inherit from ChannelMutator class. We take ChannelMutator as an example."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  How to Construct a ChannelMutator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Suppose we have a model archtecture defineed below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/liukai/miniconda3/envs/lab2max/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# define a model\n",
    "from mmengine.model import BaseModel\n",
    "from torch import nn\n",
    "import torch\n",
    "from collections import OrderedDict\n",
    "\n",
    "class MyModel(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            OrderedDict([('conv0', nn.Conv2d(3, 8, 3, 1, 1)),\n",
    "                         ('relu', nn.ReLU()),\n",
    "                         ('conv1', nn.Conv2d(8, 16, 3, 1, 1))]))\n",
    "        self.pool = nn.AdaptiveAvgPool2d(1)\n",
    "        self.head = nn.Linear(16, 1000)\n",
    "\n",
    "    def forward(self, x):\n",
    "        feature = self.net(x)\n",
    "        pool = self.pool(feature).flatten(1)\n",
    "        return self.head(pool)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two steps to fully constructing a ChannelMutator object as below. \n",
    "1. we need to initialize a ChannelMutator object.\n",
    "2. Then we need to init the ChannelMutator object with a model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11/14 14:24:13 - mmengine - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - add a input before net.conv0(net.conv0), error: net.conv0(net.conv0)\n",
      "11/14 14:24:13 - mmengine - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - add a output after head(head), error: head(head)\n",
      "The mutator has 2 mutable channel units.\n"
     ]
    }
   ],
   "source": [
    "from mmrazor.models.mutators import ChannelMutator\n",
    "\n",
    "model = MyModel()\n",
    "# initialize a ChannelMutator object\n",
    "mutator = ChannelMutator(\n",
    "    channel_unit_cfg=dict(\n",
    "        type='SequentialMutableChannelUnit',\n",
    "        default_args=dict(choice_mode='ratio'),\n",
    "        units={},\n",
    "    ),\n",
    "    parse_cfg=dict(\n",
    "        type='ChannelAnalyzer'))\n",
    "# init the ChannelMutator object with a model\n",
    "mutator.prepare_from_supernet(model)\n",
    "print(f'The mutator has {len(mutator.mutable_units)} mutable channel units.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ChannelMutator has two arguments:\n",
    "1. channel_unit_cfg: config of the MutableChannelUnit to use in the ChannelMutator.\n",
    "2. parse_cfg: the way to parse the model and get MutableChannelUnits."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are there ways to parse model and get MutableChannelUnits.\n",
    "1. Use a tracer to get MutableChannelUnits automatically.\n",
    "2. Use config dicts to indicate MutableChannelUnits.\n",
    "3. Predefine MutableChannels in the model archtecture.\n",
    "   \n",
    "The example of method 1 has been post above. We post the examples of method 2 and method 3 below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The mutator has 2 mutable channel units.\n"
     ]
    }
   ],
   "source": [
    "# 2. use config dicts to indicate MutableChannelUnits.\n",
    "from mmrazor.models.mutators import ChannelMutator\n",
    "\n",
    "model = MyModel()\n",
    "# initialize a ChannelMutator object\n",
    "mutator = ChannelMutator(\n",
    "    channel_unit_cfg=dict(\n",
    "        type='SequentialMutableChannelUnit',\n",
    "        default_args=dict(choice_mode='ratio'),\n",
    "        units={\n",
    "            'net.conv0_(0, 8)_8': {\n",
    "                'init_args': {\n",
    "                    'num_channels': 8,\n",
    "                },\n",
    "                'channels': {\n",
    "                    'input_related': [{\n",
    "                        'name': 'net.conv1',\n",
    "                    }],\n",
    "                    'output_related': [{\n",
    "                        'name': 'net.conv0',\n",
    "                    }]\n",
    "                },\n",
    "                'choice': 1.0\n",
    "            },\n",
    "            'net.conv1_(0, 16)_16': {\n",
    "                'init_args': {\n",
    "                    'num_channels': 16,\n",
    "                },\n",
    "                'channels': {\n",
    "                    'input_related': [{\n",
    "                        'name': 'head',\n",
    "                    }],\n",
    "                    'output_related': [{\n",
    "                        'name': 'net.conv1',\n",
    "                    }]\n",
    "                },\n",
    "                'choice': 1.0\n",
    "            }\n",
    "        }),\n",
    "    parse_cfg=dict(type='Config'))\n",
    "# init the ChannelMutator object with a model\n",
    "mutator.prepare_from_supernet(model)\n",
    "print(f'The mutator has {len(mutator.mutable_units)} mutable channel units.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The mutator has 2 mutable channel units.\n"
     ]
    }
   ],
   "source": [
    "# 3. Predefine MutableChannels in the model archtecture.\n",
    "\n",
    "from mmrazor.models.architectures.dynamic_ops import DynamicConv2d, DynamicLinear\n",
    "from mmrazor.models.mutables import MutableChannelUnit, MutableChannelContainer, SquentialMutableChannel\n",
    "from collections import OrderedDict\n",
    "\n",
    "class MyDynamicModel(BaseModel):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(None, None)\n",
    "        self.net = nn.Sequential(\n",
    "            OrderedDict([('conv0', DynamicConv2d(3, 8, 3, 1, 1)),\n",
    "                         ('relu', nn.ReLU()),\n",
    "                         ('conv1', DynamicConv2d(8, 16, 3, 1, 1))]))\n",
    "        self.pool = nn.AdaptiveAvgPool2d(1)\n",
    "        self.head = DynamicLinear(16, 1000)\n",
    "\n",
    "        # register MutableChannelContainer\n",
    "        MutableChannelUnit._register_channel_container(\n",
    "            self, MutableChannelContainer)\n",
    "        self._register_mutables()\n",
    "\n",
    "    def forward(self, x):\n",
    "        feature = self.net(x)\n",
    "        pool = self.pool(feature).flatten(1)\n",
    "        return self.head(pool)\n",
    "\n",
    "    def _register_mutables(self):\n",
    "        mutable1 = SquentialMutableChannel(8)\n",
    "        mutable2 = SquentialMutableChannel(16)\n",
    "        MutableChannelContainer.register_mutable_channel_to_module(\n",
    "            self.net.conv0, mutable1, is_to_output_channel=True)\n",
    "        MutableChannelContainer.register_mutable_channel_to_module(\n",
    "            self.net.conv1, mutable1, is_to_output_channel=False)\n",
    "\n",
    "        MutableChannelContainer.register_mutable_channel_to_module(\n",
    "            self.net.conv1, mutable2, is_to_output_channel=True)\n",
    "        MutableChannelContainer.register_mutable_channel_to_module(\n",
    "            self.head, mutable2, is_to_output_channel=False)\n",
    "\n",
    "\n",
    "model = MyDynamicModel()\n",
    "# initialize a ChannelMutator object\n",
    "mutator = ChannelMutator(\n",
    "    channel_unit_cfg=dict(\n",
    "        type='SequentialMutableChannelUnit',\n",
    "        default_args=dict(choice_mode='ratio'),\n",
    "        units={},\n",
    "    ),\n",
    "    parse_cfg=dict(type='Predefined'))\n",
    "# init the ChannelMutator object with a model\n",
    "mutator.prepare_from_supernet(model)\n",
    "print(f'The mutator has {len(mutator.mutable_units)} mutable channel units.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to Change the Structure of a Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The structure of a model is represented by a dict where the key is the name of a MutableChannelUnit and the value is a structure choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: 8, 1: 16}\n"
     ]
    }
   ],
   "source": [
    "print(mutator.current_choices)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can change the dict to prune the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MyDynamicModel(\n",
      "  (data_preprocessor): BaseDataPreprocessor()\n",
      "  (net): Sequential(\n",
      "    (conv0): DynamicConv2d(\n",
      "      3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n",
      "      (mutable_attrs): ModuleDict(\n",
      "        (in_channels): MutableChannelContainer(num_channels=3, activated_channels=3)\n",
      "        (out_channels): MutableChannelContainer(num_channels=8, activated_channels=4)\n",
      "      )\n",
      "    )\n",
      "    (relu): ReLU()\n",
      "    (conv1): DynamicConv2d(\n",
      "      8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)\n",
      "      (mutable_attrs): ModuleDict(\n",
      "        (in_channels): MutableChannelContainer(num_channels=8, activated_channels=4)\n",
      "        (out_channels): MutableChannelContainer(num_channels=16, activated_channels=8)\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (pool): AdaptiveAvgPool2d(output_size=1)\n",
      "  (head): DynamicLinear(\n",
      "    in_features=16, out_features=1000, bias=True\n",
      "    (mutable_attrs): ModuleDict(\n",
      "      (in_features): MutableChannelContainer(num_channels=16, activated_channels=8)\n",
      "      (out_features): MutableChannelContainer(num_channels=1000, activated_channels=1000)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "mutator.set_choices(\n",
    "    {0: 4, 1: 8}\n",
    ")\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please refer to our documents for more choices related methods."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('lab2max')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "e31a827d0913016ad78e01c7b97f787f4b9e53102dd62d238e8548bcd97ff875"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
